import os
import sys
import time
import subprocess
from os import path
import itertools
from datetime import datetime
import argparse
import json
import numpy as np
import socket

import jax

import wandb

from pprint import pprint

from pathlib import Path

# from bb_mbrl.utils import Configparser, ROOT_PATH, print_processes, Discord_Notification

class bcolors:
    """ANSI escape sequences for colouring and styling terminal output.

    Wrap text as ``f"{bcolors.OKGREEN}text{bcolors.ENDC}"``; ``ENDC`` resets
    every attribute back to the terminal default.
    """
    # Bright foreground colours (SGR 9x).
    HEADER = '\x1b[95m'     # bright magenta
    OKBLUE = '\x1b[94m'     # bright blue
    OKCYAN = '\x1b[96m'     # bright cyan
    OKGREEN = '\x1b[92m'    # bright green
    WARNING = '\x1b[93m'    # bright yellow
    FAIL = '\x1b[91m'       # bright red
    # Attribute codes.
    ENDC = '\x1b[0m'        # reset all attributes
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'


# Lower-cased job state -> ANSI colour used when print_processes renders the
# state column. Both abnormal terminal states share the failure colour.
colors = dict(
    waiting=bcolors.ENDC,
    running=bcolors.OKCYAN,
    finished=bcolors.OKGREEN,
    killed=bcolors.FAIL,
    died=bcolors.FAIL,
)


def print_processes(cmds, clear=False):
    """Print a one-line status summary for every job and return the printed text.

    Each line shows the job index, its state (coloured via ``colors``), the
    device it was assigned and, once started, its elapsed wall-clock time
    truncated to whole seconds. Elapsed time runs up to ``finish_time`` for
    jobs that have one, and up to the moment of this call otherwise.

    Args:
        cmds: job dicts; each must provide ``state``, ``job_idx``, ``n_total``,
            ``gpu``, ``start_time`` and ``finish_time`` (POSIX timestamps or
            None). ``state.lower()`` must be a key of ``colors``.
        clear: if True, first move the cursor up over ``len(cmds) + 1`` lines,
            blank them and return to the top so this table overwrites the
            previous one (assumes the previous call printed the same number
            of jobs).

    Returns:
        The job lines, each terminated by a newline (the leading blank line
        that is printed is not included).
    """
    # One reference instant for the whole table so every running job is
    # measured against the same "now".
    now = datetime.now()

    def format_process(cmd):
        """Render one job as a single status line."""
        state = cmd['state']
        color = colors[state.lower()]
        out = f"Job({cmd['job_idx']:03d}/{cmd['n_total']:03d}) {color}{state:9}{bcolors.ENDC} d={str(cmd['gpu']):6}"
        if cmd['start_time'] is not None:
            start = datetime.fromtimestamp(cmd['start_time'])
            end = datetime.fromtimestamp(cmd['finish_time']) if cmd['finish_time'] is not None else now
            # str(timedelta) ends in ".ffffff" when there are microseconds; drop them.
            out += f" {str(end - start).split('.')[0]}"
        return out

    if clear:
        # ESC[F moves the cursor to the start of the previous line. The previous
        # table occupied one blank line plus one line per job.
        prev_line = '\033[F'
        blank_line = ' ' * 100
        n_lines = len(cmds) + 1
        position = prev_line * n_lines
        print(position, end='')
        print('\n'.join([blank_line] * n_lines))
        print(position, end='')

    print()
    lines = []
    for c in cmds:
        line = format_process(c)
        print(line, flush=True)
        lines.append(line)

    return ''.join(line + '\n' for line in lines)


if __name__ == "__main__":
    import traceback  # used only by the top-level error handler below

    parser = argparse.ArgumentParser()

    parser.add_argument("--experiment", dest="experiment", type=str, required=True, help="Python script of the experiment.")
    parser.add_argument("--config", dest="config", nargs='+', type=str, required=True, help="Hyperparameter Id.")
    parser.add_argument("--gpu_ids", dest="gpu_ids", nargs='+', type=int, help="GPU ids to use.")
    parser.add_argument("--n_parallel", dest="n_parallel", type=int, default=1, help="parallel jobs per GPU")

    args = parser.parse_args()
    n_par = int(args.n_parallel)
    experiment = args.experiment
    # Wall-clock budget per job (30 days); jobs exceeding it are terminated.
    max_time = 60. * 60. * 24. * 30.

    print("\n\n\nStart Batch Testing:")

    # Child jobs run with the interpreter of the currently active environment.
    env_path = sys.exec_prefix
    python_interpreter = path.join(env_path, "bin", "python")
    print(f"{'Python:':20} {python_interpreter}")

    # Validate the experiment script with parser.error rather than assert, so the
    # check survives `python -O` and the user gets a readable message.
    experiment_script = Path(experiment)
    if not (experiment_script.is_file() and experiment_script.suffix == ".py"):
        parser.error(f"--experiment must be an existing .py file, got {experiment!r}")
    print(f"{'Script:':20} {experiment_script}")

    # TODO: Let choose specific cuda device

    args.gpu_ids = args.gpu_ids or [-1]  # if no GPU IDs are given, default to CPU.
    # One device slot per (GPU, parallel job) pair. A slot is popped when a job
    # starts and pushed back when that job reaches a terminal state.
    available_devices = [str(g) for g in args.gpu_ids] * n_par

    # --config is a Python dict expression mapping each argument name to a list of
    # candidate values; one job is created per element of the Cartesian product.
    # argparse splits it on whitespace (nargs='+'), so re-join with a single space:
    # joining with "" would fuse tokens (e.g. `for i in` -> `foriin`).
    # NOTE(security): eval executes arbitrary code taken from argv. That is
    # acceptable only because the caller controls the command line.
    # ast.literal_eval would be safer but would reject expressions such as range(...).
    keys, values = zip(*eval(" ".join(args.config)).items())
    flag_dicts = [dict(zip(keys, x)) for x in itertools.product(*values)]

    n_total = len(flag_dicts)

    cmds_all, cmds_waiting, cmds_running, cmds_finished = [], [], [], []
    for i, flags in enumerate(flag_dicts):

        now = datetime.now()
        # Intended basis for per-job log file names (not yet used).
        job_name = f"{now:%Y_%m_%d_%H_%M}_seed=?_{socket.gethostname()}"
        stdout = ""  # TODO: path.join(logs_dir, job_name + ".out")
        stderr = ""  # TODO: path.join(logs_dir, job_name + ".err.txt")

        cmds_all.append({
            'job_idx': i,
            'n_total': n_total,
            'cmd': [
                python_interpreter, '-u', str(experiment_script),
                # Alternate argument names and values: k1 v1 k2 v2 ...
                *(str(token) for token in itertools.chain.from_iterable(flags.items())),
            ],
            'process': None,
            'stdout': stdout,
            'stderr': stderr,
            'gpu': None,
            'start_time': None,
            'finish_time': None,
            'state': 'Waiting',
            'return_value': None
        })
    print(f'\nCommands to run ({len(cmds_all)}): ')
    for i, c in enumerate(cmds_all):
        print(f"\t#{i}: {' '.join(c['cmd'])}")

    print("\nDeploy Trials:", flush=True)

    print_processes(cmds_all)
    last_print = datetime.now()
    # Reversed so that list.pop() hands out jobs in their original order.
    cmds_waiting = list(reversed(cmds_all))

    try:
        n_cmds_to_run = len(cmds_waiting)

        # Loop until every job has reached a terminal state.
        while n_cmds_to_run > len(cmds_finished):

            # Launch waiting jobs while free device slots remain.
            while available_devices and cmds_waiting:

                time.sleep(1.)

                device = available_devices.pop()
                cmd = cmds_waiting.pop()

                cmd['start_time'] = datetime.now().timestamp()
                cmd['gpu'] = device

                # The experiment script is expected to accept `-gpu <id>`
                # ("-1" when no GPU ids were given).
                command = cmd['cmd'] + ["-gpu", device]
                cmd['process'] = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                cmd['state'] = 'Running'

                cmds_running.append(cmd)

                time.sleep(.5)

            # Check for finished or overdue processes.
            for cmd_i in cmds_running:

                p_i = cmd_i['process']
                return_value = p_i.poll()
                cmd_i['return_value'] = return_value
                runtime = (datetime.now() - datetime.fromtimestamp(cmd_i['start_time'])).total_seconds()

                if return_value is None and runtime <= max_time:
                    continue  # still running within its time budget

                if return_value is None:
                    # Overdue: send SIGTERM, give it up to 10 s, then force-kill.
                    # wait() reaps the child so it does not linger as a zombie.
                    p_i.terminate()
                    try:
                        p_i.wait(timeout=10.)
                    except subprocess.TimeoutExpired:
                        p_i.kill()
                        p_i.wait()
                    cmd_i['return_value'] = p_i.returncode
                    cmd_i['state'] = 'Killed'
                elif return_value == 0:
                    cmd_i['state'] = 'Finished'
                else:
                    cmd_i['state'] = 'Died'

                # Free the device slot for the next waiting job.
                available_devices.append(cmd_i['gpu'])
                cmd_i['finish_time'] = datetime.now().timestamp()

                cmds_finished.append(cmd_i)

            # Drop jobs that reached a terminal state. Use ==: `is` on strings tests
            # object identity and only works by accident of interning.
            cmds_running = [c for c in cmds_running if c['state'] == 'Running']
            time_now = datetime.now()

            if (time_now - last_print).total_seconds() > .5:
                print_processes(cmds_all, clear=True)
                last_print = time_now
            time.sleep(1)

        print_processes(cmds_all, clear=True)

    except Exception:
        print_processes(cmds_all)
        # Print the full traceback, not just the message, so failures are debuggable.
        # NOTE(review): jobs still in cmds_running are not terminated here and keep
        # running unsupervised.
        traceback.print_exc()

    print("Experiment Finished", flush=True)








